#!/usr/bin/env python3
"""
Phase 10: Three missing analyses from the original paper

1. Post-tariff residual monitoring (Figure 5 equivalent)
2. Country-specific KAOPEN marginal effects table (Table 9 equivalent)
3. Country case study figure (Figure 2 equivalent)

These require re-estimating the baseline and extended models to get
model objects (beta, residuals, fitted values).
"""

import sys
import os
import pandas as pd
import numpy as np
from pathlib import Path

# Root of the follow-up project. Defaults to the original WSL location; set the
# FOLLOWUP_DIR environment variable to run the script from another checkout.
FOLLOWUP_DIR = Path(
    os.environ.get("FOLLOWUP_DIR")
    or "/mnt/c/demographics_capital_flows/multilateral/followup"
)
PROJECT_DIR = FOLLOWUP_DIR.parent
# Make `src.*` importable. FOLLOWUP_DIR is inserted ahead of PROJECT_DIR, so a
# `src` package under FOLLOWUP_DIR (if present) takes precedence.
sys.path.insert(0, str(FOLLOWUP_DIR))
sys.path.insert(1, str(PROJECT_DIR))

from src.model import PanelGLS, estimate_baseline_model
from src.macro import filter_eba_sample

# Output locations: analysis tables and figures, plus a paper copy of each figure.
OUTPUT_DIR = FOLLOWUP_DIR / "output"
TABLE_DIR, FIG_DIR = OUTPUT_DIR / "tables", OUTPUT_DIR / "figures"
PAPER_FIG_DIR = FOLLOWUP_DIR / "paper" / "figures"

# Create every output directory up front (no-op when it already exists).
for out_dir in (TABLE_DIR, FIG_DIR, PAPER_FIG_DIR):
    out_dir.mkdir(parents=True, exist_ok=True)

# ─── Load data ───────────────────────────────────────────────────────────
_processed_dir = FOLLOWUP_DIR / "data" / "processed"
panel = pd.read_csv(_processed_dir / "full_panel.csv")
polys = pd.read_csv(_processed_dir / "demographic_polynomials.csv")

# Estimation universe: rows with an observed CA/GDP between 1986 and 2024 inclusive.
has_ca = panel['ca_gdp'].notna()
in_window = panel['year'].between(1986, 2024)
est = panel[has_ca & in_window].copy()

# Full 140-country sample
full_sample = filter_eba_sample(est, extended=True, expansion=True)
print(f"Full sample: {full_sample['iso3'].nunique()} countries, {len(full_sample):,} obs")


# ═══════════════════════════════════════════════════════════════════════════
# RE-ESTIMATE BASELINE AND EXTENDED MODELS
# ═══════════════════════════════════════════════════════════════════════════
print("\n" + "=" * 70)
print("RE-ESTIMATING MODELS FOR RESIDUALS AND MARGINAL EFFECTS")
print("=" * 70)

demo_vars = ['Z_1', 'Z_2', 'Z_3']
baseline_controls = ['fiscal_bal_gdp', 'kaopen', 'expected_growth',
                     'nfa_gdp_lag', 'log_rel_opw', 'health_exp_gdp', 'life_expectancy']
extended_controls = ['fiscal_bal_gdp', 'kaopen', 'expected_growth',
                     'nfa_gdp_lag', 'log_rel_opw', 'health_exp_gdp', 'log_lending_rate']
interaction_vars = ['Z_1_x_kaopen', 'Z_2_x_kaopen', 'Z_3_x_kaopen']

# Estimation window shared by both models (2000-2019, like the pipeline).
EST_START, EST_END = 2000, 2019


def _estimation_frame(features):
    """Return complete-case rows of ``full_sample`` inside the estimation window.

    Rows missing ``ca_gdp`` or any of ``features`` are dropped. The window filter
    is applied *before* ``.copy()`` so the result is an independent frame:
    assigning new columns to it (e.g. residuals) does not raise pandas'
    SettingWithCopyWarning, as filtering a copy afterwards did.
    """
    frame = full_sample.dropna(subset=['ca_gdp'] + features)
    in_window = (frame['year'] >= EST_START) & (frame['year'] <= EST_END)
    return frame[in_window].copy()


def _fit_panel_gls(frame, features):
    """Fit ``ca_gdp`` on ``features`` with PanelGLS and record the feature names.

    The iso3 and year columns are passed to ``PanelGLS.fit`` alongside y and X.
    """
    model = PanelGLS()
    model.fit(frame['ca_gdp'].values, frame[features].values,
              frame['iso3'].values, frame['year'].values)
    model.feature_names = features
    return model


def _coefficient_table(model, features):
    """Tabulate coefficient, standard error, t-stat and p-value per feature."""
    return pd.DataFrame({
        'variable': features,
        'coefficient': model.beta,
        'std_error': model.se,
        't_stat': model.beta / model.se,
        'p_value': model.pvalues,
    })


# --- Baseline model ---
baseline_features = demo_vars + baseline_controls
df_base = _estimation_frame(baseline_features)
print(f"\nBaseline sample: {df_base['iso3'].nunique()} countries, {len(df_base)} obs")

gls_base = _fit_panel_gls(df_base, baseline_features)

# Attach in-sample residuals / fitted values to the rows used in the fit.
df_base['resid_baseline'] = gls_base.resid
df_base['fitted_baseline'] = gls_base.fitted

z1_idx = baseline_features.index('Z_1')
print(f"  R² = {gls_base.r_squared:.4f}")
print(f"  Z₁ = {gls_base.beta[z1_idx]:.2f} (p={gls_base.pvalues[z1_idx]:.4f})")

# Save updated baseline regression CSV
base_csv = _coefficient_table(gls_base, baseline_features)
base_csv.to_csv(TABLE_DIR / "regression_baseline_demo_plus_eba_140.csv", index=False)
print("  Saved 140-country baseline regression CSV")

# --- Extended model ---
# Build any demographic x KAOPEN interaction the panel does not already carry.
for v in interaction_vars:
    if v not in full_sample.columns:
        base, inter = v.rsplit('_x_', 1)
        full_sample[v] = full_sample[base] * full_sample[inter]

extended_features = demo_vars + extended_controls + interaction_vars
df_ext = _estimation_frame(extended_features)
print(f"\nExtended sample: {df_ext['iso3'].nunique()} countries, {len(df_ext)} obs")

gls_ext = _fit_panel_gls(df_ext, extended_features)
zk_idx = extended_features.index('Z_1_x_kaopen')
print(f"  R² = {gls_ext.r_squared:.4f}")
print(f"  Z₁×KAOPEN = {gls_ext.beta[zk_idx]:.2f} "
      f"(p={gls_ext.pvalues[zk_idx]:.4f})")

# Save updated extended regression CSV (the same table under two file names).
ext_csv = _coefficient_table(gls_ext, extended_features)
for fname in ("regression_extended_plus_rates_140.csv",
              "regression_extended_plus_interactions_140.csv"):
    ext_csv.to_csv(TABLE_DIR / fname, index=False)
print("  Saved 140-country extended regression CSVs")


# ═══════════════════════════════════════════════════════════════════════════
# ANALYSIS 1: POST-TARIFF RESIDUAL MONITORING
# ═══════════════════════════════════════════════════════════════════════════
print("\n" + "=" * 70)
print("ANALYSIS 1: POST-TARIFF RESIDUAL MONITORING")
print("=" * 70)

import matplotlib
matplotlib.use('Agg')  # headless backend: figures are only written to disk
import matplotlib.pyplot as plt

focus_countries = ['USA', 'CHN', 'DEU', 'JPN', 'KOR', 'MEX', 'CAN']
tariff_year = 2018

# NOTE(review): df_base is restricted to 2000-2019 and these are the in-sample
# residuals of the baseline fit, so the "post" window holds at most 2018-2019.
# Monitoring later years would require out-of-sample predictions.

results_tariff = []
yearly_tariff = {}

for iso3 in focus_countries:
    cdf = df_base[df_base['iso3'] == iso3].sort_values('year')
    if len(cdf) == 0:
        print(f"  {iso3}: no data in baseline sample")
        continue

    pre = cdf[cdf['year'] < tariff_year]
    post = cdf[cdf['year'] >= tariff_year]

    row = {
        'iso3': iso3,
        'pre_mean_resid': pre['resid_baseline'].mean() if len(pre) > 0 else np.nan,
        'post_mean_resid': post['resid_baseline'].mean() if len(post) > 0 else np.nan,
        'pre_n': len(pre),
        'post_n': len(post),
    }
    row['resid_shift'] = row['post_mean_resid'] - row['pre_mean_resid']
    results_tariff.append(row)

    yearly_tariff[iso3] = cdf[['year', 'resid_baseline', 'ca_gdp', 'fitted_baseline']].copy()

tariff_df = pd.DataFrame(results_tariff)
print(f"\n{'iso3':>5} {'pre_mean':>10} {'post_mean':>10} {'shift':>10} {'pre_n':>6} {'post_n':>6}")
print("-" * 55)
for _, r in tariff_df.iterrows():
    print(f"{r['iso3']:>5} {r['pre_mean_resid']:>10.3f} {r['post_mean_resid']:>10.3f} "
          f"{r['resid_shift']:>10.3f} {int(r['pre_n']):>6} {int(r['post_n']):>6}")

print("\nPositive shift = CA higher than model predicts post-tariffs")
print("Negative shift = CA lower than model predicts post-tariffs")

tariff_df.to_csv(TABLE_DIR / "tariff_residual_shift.csv", index=False)

# --- Plot tariff residuals (Figure 5 equivalent) ---
available = [c for c in focus_countries if c in yearly_tariff]
if not available:
    # Guard: with no countries, n_cols would be 0 and the grid sizing below
    # would raise ZeroDivisionError.
    print("\nNo focus countries in baseline sample; skipped fig5_tariff_residuals.png")
else:
    n = len(available)
    n_cols = min(3, n)
    n_rows = (n + n_cols - 1) // n_cols

    fig, axes = plt.subplots(n_rows, n_cols, figsize=(5 * n_cols, 4 * n_rows), squeeze=False)

    for idx, iso3 in enumerate(available):
        ax = axes[idx // n_cols, idx % n_cols]
        cdf = yearly_tariff[iso3]

        # Blue bars before the tariff year, red from the tariff year onward.
        colors = ['#1f77b4' if y < tariff_year else '#d62728' for y in cdf['year']]
        ax.bar(cdf['year'], cdf['resid_baseline'], color=colors, alpha=0.7, edgecolor='none')
        ax.axhline(y=0, color='black', linewidth=0.8)
        ax.axvline(x=tariff_year - 0.5, color='red', linewidth=1.5, linestyle='--', alpha=0.7)

        pre = cdf[cdf['year'] < tariff_year]['resid_baseline']
        post = cdf[cdf['year'] >= tariff_year]['resid_baseline']
        if len(pre) > 0 and len(post) > 0:
            shift = post.mean() - pre.mean()
            ax.text(0.98, 0.02, f"Shift: {shift:+.2f} pp",
                    transform=ax.transAxes, fontsize=9, ha='right', va='bottom',
                    bbox=dict(boxstyle='round', facecolor='lightyellow', alpha=0.8))

        ax.set_title(iso3, fontsize=12, fontweight='bold')
        ax.set_ylabel('Residual (pp CA/GDP)', fontsize=9)
        ax.grid(axis='y', alpha=0.3)

    # Hide grid cells left over when n is not a multiple of n_cols.
    for idx in range(n, n_rows * n_cols):
        axes[idx // n_cols, idx % n_cols].set_visible(False)

    plt.suptitle('Model Residuals Before & After Tariffs (2018)', fontsize=14, y=1.02)
    plt.tight_layout()
    fig.savefig(FIG_DIR / "fig5_tariff_residuals.png", dpi=150, bbox_inches='tight')
    fig.savefig(PAPER_FIG_DIR / "fig5_tariff_residuals.png", dpi=150, bbox_inches='tight')
    plt.close(fig)
    print("\nSaved fig5_tariff_residuals.png")


# ═══════════════════════════════════════════════════════════════════════════
# ANALYSIS 2: COUNTRY-SPECIFIC KAOPEN MARGINAL EFFECTS
# ═══════════════════════════════════════════════════════════════════════════
print("\n" + "=" * 70)
print("ANALYSIS 2: COUNTRY-SPECIFIC KAOPEN MARGINAL EFFECTS")
print("=" * 70)

# Map each marginal-effect parameter to the extended-model term it comes from.
# A term missing from the model contributes 0.
coeff_map = dict(zip(extended_features, gls_ext.beta))
_coeff_terms = {
    'beta_kaopen': 'kaopen',
    'delta_1': 'Z_1_x_kaopen',
    'delta_2': 'Z_2_x_kaopen',
    'delta_3': 'Z_3_x_kaopen',
}
coeffs = {key: coeff_map.get(term, 0) for key, term in _coeff_terms.items()}

print("  Interaction coefficients:")
for label, key in (("beta_kaopen", 'beta_kaopen'),
                   ("delta_1 (Z₁×K)", 'delta_1'),
                   ("delta_2 (Z₂×K)", 'delta_2'),
                   ("delta_3 (Z₃×K)", 'delta_3')):
    print(f"    {label} = {coeffs[key]:.4f}")

def marginal_effect(z1, z2, z3, c):
    """Return dCA/dKAOPEN evaluated at demographic factors (z1, z2, z3).

    The KAOPEN slope is linear in the factors::

        c['beta_kaopen'] + c['delta_1']*z1 + c['delta_2']*z2 + c['delta_3']*z3

    ``c`` must provide those four keys. The z arguments may be scalars or any
    objects supporting ``*`` and ``+`` (e.g. numpy arrays / pandas Series).
    """
    effect = c['beta_kaopen']
    # Accumulate left to right, matching the formula's evaluation order.
    for delta, z in zip((c['delta_1'], c['delta_2'], c['delta_3']), (z1, z2, z3)):
        effect = effect + delta * z
    return effect


# Current demographics: latest historical year (≤2024) with Z data.
# groupby().last() takes the last *non-null* value of each column within a
# country, so `kaopen` (reported as current_kaopen) is the latest available
# KAOPEN reading, which may come from an earlier year than `year`.
hist = panel[panel['year'] <= 2024].dropna(subset=['Z_1', 'Z_2', 'Z_3'])
current = hist.sort_values('year').groupby('iso3').last().reset_index()

proj_years = [2030, 2040, 2050]
me_columns = (['iso3', 'current_year', 'current_kaopen', 'marginal_effect_current']
              + [f'marginal_effect_{y}' for y in proj_years])

rows_me = []
for _, r in current.iterrows():
    iso = r['iso3']
    row = {
        'iso3': iso,
        'current_year': int(r['year']),
        'current_kaopen': r.get('kaopen', np.nan),
        'marginal_effect_current': marginal_effect(r['Z_1'], r['Z_2'], r['Z_3'], coeffs),
    }

    # Projected marginal effects use the first matching row in the polynomial file.
    for proj_year in proj_years:
        proj = polys[(polys['iso3'] == iso) & (polys['year'] == proj_year)]
        if len(proj) > 0:
            p = proj.iloc[0]
            row[f'marginal_effect_{proj_year}'] = marginal_effect(
                p['Z_1'], p['Z_2'], p['Z_3'], coeffs)
        else:
            row[f'marginal_effect_{proj_year}'] = np.nan

    rows_me.append(row)

# Passing `columns` keeps the schema (and the sort key) even when no country qualifies.
me_df = (pd.DataFrame(rows_me, columns=me_columns)
         .sort_values('marginal_effect_current', ascending=False))
me_df.to_csv(TABLE_DIR / "openness_marginal_effects.csv", index=False)

print(f"\n  Marginal effects computed for {len(me_df)} countries")
print("\n  Top 10 (largest positive dCA/dKAOPEN — most to gain from openness):")
for _, r in me_df.head(10).iterrows():
    print(f"    {r['iso3']:>5}: {r['marginal_effect_current']:>+7.2f} (current)  "
          f"{r.get('marginal_effect_2030', np.nan):>+7.2f} (2030)  "
          f"{r.get('marginal_effect_2050', np.nan):>+7.2f} (2050)")

print("\n  Bottom 10 (largest negative — most to lose from openness):")
for _, r in me_df.tail(10).iterrows():
    print(f"    {r['iso3']:>5}: {r['marginal_effect_current']:>+7.2f} (current)  "
          f"{r.get('marginal_effect_2030', np.nan):>+7.2f} (2030)  "
          f"{r.get('marginal_effect_2050', np.nan):>+7.2f} (2050)")

# Focus countries table (for paper)
focus_me = ['JPN', 'DEU', 'USA', 'KOR', 'CHN', 'IND', 'BRA', 'NGA', 'IRN', 'VNM', 'BGD', 'GTM']
me_focus = me_df[me_df['iso3'].isin(focus_me)].copy()
# Reorder to match focus_me ordering
me_focus['_order'] = me_focus['iso3'].map({c: i for i, c in enumerate(focus_me)})
me_focus = me_focus.sort_values('_order').drop(columns='_order')

print("\n  Focus countries marginal effects (dCA/dKAOPEN, pp per unit):")
print(f"  {'Country':>5}  {'KAOPEN':>7}  {'Current':>8}  {'2030':>8}  {'2040':>8}  {'2050':>8}")
print(f"  {'-'*50}")
for _, r in me_focus.iterrows():
    print(f"  {r['iso3']:>5}  {r['current_kaopen']:>7.2f}  "
          f"{r['marginal_effect_current']:>+8.2f}  "
          f"{r.get('marginal_effect_2030', np.nan):>+8.2f}  "
          f"{r.get('marginal_effect_2040', np.nan):>+8.2f}  "
          f"{r.get('marginal_effect_2050', np.nan):>+8.2f}")


def _signed_term(coef, label, decimals=2):
    """Render ``coef·label`` with an explicit sign, e.g. '+ 0.12·Z₁' or '- 0.12·Z₁'.

    Replaces hard-coded '+'/parentheses in the formula, which printed
    '+ -0.12·Z₁' whenever a coefficient's sign differed from the assumed one.
    """
    sign = '-' if coef < 0 else '+'
    return f"{sign} {abs(coef):.{decimals}f}·{label}"


me_formula = (f"{coeffs['beta_kaopen']:.2f} "
              f"{_signed_term(coeffs['delta_1'], 'Z₁')} "
              f"{_signed_term(coeffs['delta_2'], 'Z₂')} "
              f"{_signed_term(coeffs['delta_3'], 'Z₃', decimals=3)}")

# Save markdown table for paper. UTF-8 is explicit because the text contains
# non-ASCII characters (·, Z₁ …) that a non-UTF-8 locale encoding cannot write.
with open(TABLE_DIR / "table_marginal_effects_focus.md", 'w', encoding='utf-8') as f:
    f.write("# Marginal Effect of Financial Openness (dCA/dKAOPEN)\n\n")
    f.write("The marginal effect of a one-unit increase in KAOPEN on CA/GDP,\n")
    f.write("evaluated at each country's demographics (current and projected).\n\n")
    f.write("| Country | KAOPEN | Current | 2030 | 2040 | 2050 |\n")
    f.write("|:--------|------:|-------:|-----:|-----:|-----:|\n")
    for _, r in me_focus.iterrows():
        f.write(f"| {r['iso3']} | {r['current_kaopen']:.2f} | "
                f"{r['marginal_effect_current']:+.2f} | "
                f"{r.get('marginal_effect_2030', np.nan):+.2f} | "
                f"{r.get('marginal_effect_2040', np.nan):+.2f} | "
                f"{r.get('marginal_effect_2050', np.nan):+.2f} |\n")
    f.write("\n*Positive = further opening improves CA; negative = further opening worsens CA.*\n")
    f.write(f"*Based on extended model: dCA/dKAOPEN = {me_formula}*\n")

print("\nSaved table_marginal_effects_focus.md")


# ═══════════════════════════════════════════════════════════════════════════
# ANALYSIS 3: COUNTRY CASE STUDY FIGURE
# ═══════════════════════════════════════════════════════════════════════════
print("\n" + "=" * 70)
print("ANALYSIS 3: COUNTRY CASE STUDY FIGURE")
print("=" * 70)

# Compute demographic contributions using baseline model coefficients.
# The Z slopes are looked up by feature name (gls_base.feature_names is set
# after fitting) rather than by position, so reordering the baseline feature
# list cannot silently attach the wrong slope to a factor.
z_cols = ['Z_1', 'Z_2', 'Z_3']
base_coef = dict(zip(gls_base.feature_names, gls_base.beta))
z_betas = np.array([base_coef[z] for z in z_cols])
print(f"  Z betas: Z₁={z_betas[0]:.2f}, Z₂={z_betas[1]:.4f}, Z₃={z_betas[2]:.4f}")

case_countries = ['CHN', 'IND', 'IDN', 'JPN', 'USA', 'DEU',
                  'BRA', 'NGA', 'ZAF', 'KOR', 'GBR', 'IRN']

profiles = {}
for country in case_countries:
    cdf = panel[panel['iso3'] == country].copy()
    if len(cdf) == 0:
        print(f"  {country}: no data")
        continue

    cdf = cdf.sort_values('year')
    # Demographic contribution = beta_Z1*Z_1 + beta_Z2*Z_2 + beta_Z3*Z_3.
    cdf['demo_contribution'] = (z_betas[0] * cdf['Z_1'] +
                                z_betas[1] * cdf['Z_2'] +
                                z_betas[2] * cdf['Z_3'])
    for beta_z, z in zip(z_betas, z_cols):
        cdf[f'contribution_{z}'] = beta_z * cdf[z]

    profiles[country] = cdf
    n_hist = cdf[cdf['year'] <= 2024].shape[0]
    n_proj = cdf[cdf['year'] > 2024].shape[0]
    print(f"  {country}: {n_hist} historical + {n_proj} projected years")

# --- Plot Figure 2 equivalent ---
countries_to_plot = [c for c in case_countries if c in profiles]
n_countries = len(countries_to_plot)
if n_countries == 0:
    # Guard: with no countries, n_cols would be 0 and the grid sizing below
    # would raise ZeroDivisionError.
    print("\nNo case-study countries in panel; skipped fig2_demographic_contributions.png")
else:
    n_cols = min(3, n_countries)
    n_rows = (n_countries + n_cols - 1) // n_cols

    # squeeze=False always returns a 2-D axes array, whatever the grid shape.
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(5 * n_cols, 4 * n_rows),
                             sharex=False, sharey=False, squeeze=False)

    for idx, country in enumerate(countries_to_plot):
        ax = axes[idx // n_cols, idx % n_cols]
        cdf = profiles[country].sort_values('year')

        # Historical vs projected split
        hist_part = cdf[cdf['year'] <= 2024]
        proj_part = cdf[cdf['year'] > 2024]

        # Plot demographic contribution (historical solid, projected dashed)
        ax.plot(hist_part['year'], hist_part['demo_contribution'], 'b-', linewidth=2,
                label='Demo. contribution')
        if len(proj_part) > 0:
            # Prepend the last historical point so the dashed line joins the solid one.
            connect = pd.concat([hist_part.tail(1), proj_part])
            ax.plot(connect['year'], connect['demo_contribution'], 'b--', linewidth=1.5,
                    alpha=0.6, label='Projected')

        # Plot actual CA/GDP (only historical)
        if 'ca_gdp' in hist_part.columns:
            ca_hist = hist_part.dropna(subset=['ca_gdp'])
            if len(ca_hist) > 0:
                ax.plot(ca_hist['year'], ca_hist['ca_gdp'], 'k-', linewidth=1,
                        alpha=0.4, label='Actual CA/GDP')

        ax.axhline(y=0, color='gray', linewidth=0.5, linestyle='-')
        ax.axvline(x=2024.5, color='gray', linewidth=0.5, linestyle=':', alpha=0.5)
        ax.set_title(country, fontsize=12, fontweight='bold')
        ax.set_xlabel('Year', fontsize=9)
        ax.set_ylabel('CA/GDP (pp)', fontsize=9)
        ax.grid(alpha=0.3)
        ax.legend(fontsize=7, loc='best')

    # Hide unused subplots
    for idx in range(n_countries, n_rows * n_cols):
        axes[idx // n_cols, idx % n_cols].set_visible(False)

    plt.suptitle('Demographic Contribution to Current Account / GDP (140-country model)',
                 fontsize=14, y=1.02)
    plt.tight_layout()
    fig.savefig(FIG_DIR / "fig2_demographic_contributions.png", dpi=150, bbox_inches='tight')
    fig.savefig(PAPER_FIG_DIR / "fig2_demographic_contributions.png", dpi=150, bbox_inches='tight')
    plt.close(fig)
    print("\nSaved fig2_demographic_contributions.png")

# --- Save country fit statistics ---
fit_columns = ['iso3', 'mean_ca_gdp', 'mean_demo_contribution', 'correlation', 'n_obs']
fit_stats = []
for country in countries_to_plot:
    cdf = profiles[country]
    window = cdf[(cdf['year'] >= 2000) & (cdf['year'] <= 2019)].dropna(subset=['ca_gdp'])
    if len(window) == 0:
        continue
    fit_stats.append({
        'iso3': country,
        'mean_ca_gdp': window['ca_gdp'].mean(),
        'mean_demo_contribution': window['demo_contribution'].mean(),
        'correlation': window['demo_contribution'].corr(window['ca_gdp']),
        'n_obs': len(window),
    })

# Passing `columns` keeps the CSV header even if no country has data in the window.
fit_df = pd.DataFrame(fit_stats, columns=fit_columns)
fit_df.to_csv(TABLE_DIR / "country_case_study_fit.csv", index=False)

print("\n  Country case study fit (2000-2019):")
print(f"  {'Country':>5}  {'Mean CA':>8}  {'Mean Demo':>10}  {'Corr':>6}  {'N':>4}")
print(f"  {'-'*40}")
for _, r in fit_df.iterrows():
    print(f"  {r['iso3']:>5}  {r['mean_ca_gdp']:>+8.2f}  {r['mean_demo_contribution']:>+10.2f}  "
          f"{r['correlation']:>6.3f}  {int(r['n_obs']):>4}")


# ═══════════════════════════════════════════════════════════════════════════
print("\n" + "=" * 70)
print("PHASE 10 COMPLETE")
print("=" * 70)

# Every file this script writes, grouped by destination directory. The
# interactions CSV (a second copy of the extended regression) was previously
# written but missing from this list.
table_files = [
    'tariff_residual_shift.csv',
    'openness_marginal_effects.csv',
    'table_marginal_effects_focus.md',
    'country_case_study_fit.csv',
    'regression_baseline_demo_plus_eba_140.csv',
    'regression_extended_plus_rates_140.csv',
    'regression_extended_plus_interactions_140.csv',
]
figure_files = ['fig5_tariff_residuals.png', 'fig2_demographic_contributions.png']

print("\nOutputs:")
print("  Tables:")
for fname in table_files:
    print(f"    {TABLE_DIR / fname}")
print("  Figures:")
for fname in figure_files:
    print(f"    {FIG_DIR / fname}")
print("  Paper figures:")
for fname in figure_files:
    print(f"    {PAPER_FIG_DIR / fname}")
print("\nDone.")
